-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ZarrTrace #7540
Add ZarrTrace #7540
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #7540 +/- ##
==========================================
- Coverage 92.82% 92.81% -0.01%
==========================================
Files 106 107 +1
Lines 17765 18168 +403
==========================================
+ Hits 16490 16863 +373
- Misses 1275 1305 +30
|
This is an important issue to keep track of when we'll eventually want to read the zarr store and create an |
pymc/backends/zarr.py
Outdated
_dtype = np.dtype(dtype) | ||
if np.issubdtype(_dtype, np.floating): | ||
return (np.nan, _dtype, None) | ||
elif np.issubdtype(_dtype, np.integer): | ||
return (-1_000_000, _dtype, None) | ||
elif np.issubdtype(_dtype, "bool"): | ||
return (False, _dtype, None) | ||
elif np.issubdtype(_dtype, "str"): | ||
return ("", _dtype, None) | ||
elif np.issubdtype(_dtype, "datetime64"): | ||
return (np.datetime64(0, "Y"), _dtype, None) | ||
elif np.issubdtype(_dtype, "timedelta64"): | ||
return (np.timedelta64(0, "Y"), _dtype, None) | ||
else: | ||
return (None, _dtype, numcodecs.Pickle()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Question from my own ignorance, since I don't understand so much how fill values are implemented. Are we just hoping that these fill values don't actually occur in the data?
If so, this seems especially perilous for bool
😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, they are supposed to be the initialisation values for the entries. When the sampler completes its run, all entries will be filled with the correct value. Zarr just needs you to tell it what value to give to unwritten places. In the storage, these entries are never actually written, they are produced when you ask for the concrete values in the array.
The dangerous part is that xarray is interpreting fill_value
as an indicator of whether the actual value should be masked to nan. This seems to be because of the netcdf standard treats fill_value as something completely different.
To keep things as clean as possible, I’ll store the draw_idx of each chain in a separate group that should never be converted to xarray.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, that makes a lot more sense now, thanks for the explanation!
In case it's non-obvious to more than me, maybe it would be helpful to try to make this more self-evident. Perhaps by calling the function get_initial_fill_value_and_codec
, or make some little comment that the fill value is used for initialization?
Yes, therefore I would recommend not to use them for any new implementation.
Just to clarify:
It should be quite simple to implement a
Yes and No. I would say ArviZ is a first-class citizen, because
I consider 2. the primary weakness, and it's the only reason why I don't use McBackend by default. First we must find answers to:
Protocol buffers I used them because they are convenient for specifying a data structure and not having to write serialization/deserialization code for it. And they are only for the constant metadata From the Python perspective this could also be done with Is that tight integration? The important design decision is not which implementation is used to serialize/deserialize metadata, but rather to freeze and detach these (meta)data from chains, draws and stats:
|
Thanks @michaelosthege for the feedback!
I understand what the two backends for McBackend offer and that McBackend already has a test suite. Despite this, I'll try to argue in favor of writing something that's detached from McBackend.
The way I see this is that McBackend offers a signature to convert from a kind of storage (like
The key thing is that I added these groups to the zarr hierarchy, having them as
I decided to only focus on MCMC for now, and I'm trying to make
Yes, you can deserialize almost all of the contents into C++ or Rust. zarr can be readable from python, Julia, C++, rust, javascript and Java. The only content that would not be readable in other languages would come from arrays with The latter isn't a problem in my opinion because it is related exclusively to the python pymc step methods, and I detached it to its own private group in the zarr hierarchy. The former might be more problematic, but since Having said that, there are other benefits that we would get if we were to rely on zarr directly, such as:
I think that these added benefits plus the drop in maintenance costs in the long run warrant using zarr directly and not through a new backend for McBackend. |
@lucianopaz, have you done some benchmarks with this yet (in particular with S3)? I'm a bit concerned that with (1, 1, ...) chunk size that I/O will be a bottleneck. |
3206597
to
69bb2ac
Compare
No, I haven't. But I've made the chunksize customizable now via the Anyway, my long term goal is to add something like checkpoints during sampling where the trace gets dumped into a file along with the sampling state of the step methods. I think that I'll eventually make the chunks align with that, so that we don't lose samples that were drawn before the checkpoint if sampling gets terminated afterwards (before having finished). |
By the way, I've added a |
1f6d646
to
413d724
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I love this direction. I left a comment on ArviZ integration.
I also have more ideas of things that can be done to integrate better the sampling outputs with inferencedata but it might be better to address them in follow up PRs. Not having to go into the current ArviZ converter might help get this things off the ground. Many of these are around since #5160
Also, anything on ArviZ side that can help with this let me know
Better sample_stats. sample_stats doesn't necessarily need to restrict itself to having chain, draw
dimensions in all its variables. the mass matrix could also go in there and a divergence_id
even (with extra coordinate values or a multiindex to store the start and end points of divergences) which would complement the boolean diverging
variable with chain, draw
dimension.
samples in the unconstrained space. related to #6721 and to a lesser extent arviz-devs/arviz-base#8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I absolutely love this! :D
pymc/backends/zarr.py
Outdated
def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None): # type: ignore[override] | ||
self.chain = chain | ||
|
||
def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not check the source, but I think zarr will write the whole chunk each time we set a draw here, even if that chunk is not full yet. If that is indeed the case, we should be able to speed this up a lot if draws_per_chunk is >1 if we buffer draws_per_chunk
draws, and the set values in one go.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, it would be a shame that zarr itself doesn't buffer the write until the chunk is filled.
Great! Yes, let's try to address those in other PRs.
I don't know if
I tried to stay close to how pymc is returning things right now. I agree that this could be greatly improved, but maybe we can do so in follow up iterations.
I'll try to add this. It doesn't seem to be difficult. |
0c5e73b
to
ee0a36d
Compare
I just did a little experiment with strace, to see what zarr does when we write a value to a chunk. I wrote a little script that creates the array, writes two chunks and then closes the file. With import zarr
import numpy as np
import sys
store = zarr.DirectoryStore("zarr-test.zarr")
#store = zarr.LRUStoreCache(store, max_size=2**28)
data = zarr.open(store)
foo = data.array("foo", np.zeros((0, 0, 0)), chunks=(10, 10, 1))
foo.resize((1000, 10, 1))
# Mark the position in the code to make it easier to find the correct part
print("start--", flush=True, file=sys.stderr)
foo[0, 0, 0] = 1.0
print("mid--", flush=True, file=sys.stderr)
foo[1, 0, 0] = 2.0
print("done--", flush=True, file=sys.stderr) The first write triggers this:
The second write triggers this:
So there's a lot going on for each write to the array. If I read this correctly, for the second store it actually reads the chunk from the disc, then modifies the chunk with the indexing update, writes the new chunk to a temporary file and then replaces that with the original chunk file. For the first write it skips reading in the chunk data, because there still is nothing there to read. So I think if we want to get good performance from this, we should try to combine writes, by buffering |
sample
pymc/backends/zarr.py
Outdated
self.fn = model.compile_fn( | ||
self.vars, | ||
inputs=model.value_vars, | ||
on_unused_input="ignore", | ||
borrow_vars=True, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did I mess it up. Doesn't BaseTrace.__init__
already create the fn you need?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All my comments about borrow and trust_input were based on this assumption. Note the base class does not call model.compile_fn
, it sidesteps it completely, hence why I was surprised you were changing it previously
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, the BaseTrace
for each chain (ZarrChain
) is initialized in ZarrTrace
. Now I'm sending off the fn
that is compiled in ZarrTrace
over to every chain instance to use it without having to recompile it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should do that for MultiTrace as well. I had that in mind as a follow up in #7578 where a lot of the wall time was just avoiding needless function compilations.
My early comment still stands you shouldn't call model.compile_fn
and you should (unless you don't want the speedup for some reason) have trust_input and borrow inputs and outputs to avoid deep copies needlessly. AFAICT the same logic should be applicable to the new and old traces.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
However, the reason I didn't do it for multiprocessing yet is that we have to be careful about point functions with RNGs (say with stochastic Deterministics). We actually need to copy the function and provide new shared RNGs for each function. If you're sharing self.fn
across multiple chains (and they are called in different processes) the same concern may apply here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is a related issue: #7588
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you don't copy the function and do the RNG swap it's better not to share for now, as it may mess up the RNG updates and you get crappy random draws.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is the kind of logic needed for this function copy and RNG swap: https://github.com/pymc-devs/pymc-extras/blob/c1809e8149fc89ac6eadf4bf73050ea6fe82955c/pymc_experimental/sampling/optimizations/conjugate_sampler.py#L84-L98
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the problem that the random generators inside self.fn
are shared across chains? What I mean is that, is the problem that when different chains call the same fn
, the random generator state will advance mixing up the different chains? I can see that this could happen in the context of a single process that runs multiple chains, either in threads or sequentially. But I don't see this happening in different processes.
Or is the problem that we don't manage to control the random generator state that gets called in fn
as we do with step methods? If this is the real problem, we could address it using the RandomGeneratorState
, and random_generator_from_state
and get_state_from_generator
to create copies of random generators that adhere to whatever state was inputed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Anyway, I think that #7588 really captures the core problem with what might be going on, so I don't think that it's worth opening a different issue. This PR doesn't close the issue though, but I don't think that it's intention should be to do so. I have an idea to try and tackle 7588 using the RandomGeneratorState
stuff that I developed here.
9f4c013
to
482bb9f
Compare
@ricardoV94, I think that I fixed the merge conflicts with upstream main and I've removed the code I had added to |
sample
Thanks @ricardoV94 ! I’ll merge and hope for the best |
This has big potential! Thanks for it |
Description
This PR is related to #7503. It specifically focuses on having a way to store intermediate trace results and the step methods sampling state somewhere (See task 2 of #7508).
To be honest, the current situation of the
MultiTrace
andNDArray
backends is terrible. These backend classes have inconsistent signatures across subclasses, and it's very awkward to write new backends that adhere to them.McBackend
was an attempt to make things sane again. As far as I understand,McBackend
does support ways to dump samples to disk instead of holding them in memory using theClickHouse
database. However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.These considerations brought me to the approach I'm pursuing in this PR: add a backend that uses zarr. Using zarr has the following benefits:
xarray
can read zarr stores directly making it possible to writeInferenceData
objects to disk directly almost without even having to call a converter.object
dtyped arrays using thenumcodec
package. This makes it possible use the same store to hold sample stats warning objects and step methodssampling_state
in the same place as the actual samples from the posterior.Having stated all of these considerations I intend to:
ZarrTrace
integrate well withpymc.sample
Replace theMultiTrace
andNDArray
backend defaults with their Zarr counterpartsZarrTrace
ZarrChain.record
Make it possible to load the zarr trace backend and resume sampling from it.Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7540.org.readthedocs.build/en/7540/